{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ecad719c",
   "metadata": {},
   "source": [
    "# Using Weights & Biases with Tune\n",
    "\n",
    "(tune-wandb-ref)=\n",
    "\n",
    "[Weights & Biases](https://www.wandb.ai/) (Wandb) is a tool for experiment\n",
    "tracking, model optimizaton, and dataset versioning. It is very popular\n",
    "in the machine learning and data science community for its superb visualization\n",
    "tools.\n",
    "\n",
    "```{image} /images/wandb_logo_full.png\n",
    ":align: center\n",
    ":alt: Weights & Biases\n",
    ":height: 80px\n",
    ":target: https://www.wandb.ai/\n",
    "```\n",
    "\n",
    "Ray Tune currently offers two lightweight integrations for Weights & Biases.\n",
    "One is the {ref}`WandbLoggerCallback <tune-wandb-logger>`, which automatically logs\n",
    "metrics reported to Tune to the Wandb API.\n",
    "\n",
    "The other one is the {ref}`@wandb_mixin <tune-wandb-mixin>` decorator, which can be\n",
    "used with the function API. It automatically\n",
    "initializes the Wandb API with Tune's training information. You can just use the\n",
    "Wandb API like you would normally do, e.g. using `wandb.log()` to log your training\n",
    "process.\n",
    "\n",
    "```{contents}\n",
    ":backlinks: none\n",
    ":local: true\n",
    "```\n",
    "\n",
    "## Running A Weights & Biases Example\n",
    "\n",
    "In the following example we're going to use both of the above methods, namely the `WandbLoggerCallback` and\n",
    "the `wandb_mixin` decorator to log metrics.\n",
    "Let's start with a few crucial imports:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "100bcf8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import wandb\n",
    "\n",
    "from ray import tune\n",
    "from ray.tune import Trainable\n",
    "from ray.tune.integration.wandb import (\n",
    "    WandbLoggerCallback,\n",
    "    WandbTrainableMixin,\n",
    "    wandb_mixin,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "Next, let's define an easy `objective` function (a Tune `Trainable`) that reports a random loss to Tune.\n",
    "The objective function itself is not important for this example, since we want to focus on the Weights & Biases\n",
    "integration primarily."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def objective(config, checkpoint_dir=None):\n",
    "    for i in range(30):\n",
    "        loss = config[\"mean\"] + config[\"sd\"] * np.random.randn()\n",
    "        tune.report(loss=loss)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Given that you provide an `api_key_file` pointing to your Weights & Biases API key, you cna define a\n",
    "simple grid-search Tune run using the `WandbLoggerCallback` as follows:"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def tune_function(api_key_file):\n",
    "    \"\"\"Example for using a WandbLoggerCallback with the function API\"\"\"\n",
    "    analysis = tune.run(\n",
    "        objective,\n",
    "        metric=\"loss\",\n",
    "        mode=\"min\",\n",
    "        config={\n",
    "            \"mean\": tune.grid_search([1, 2, 3, 4, 5]),\n",
    "            \"sd\": tune.uniform(0.2, 0.8),\n",
    "        },\n",
    "        callbacks=[\n",
    "            WandbLoggerCallback(api_key_file=api_key_file, project=\"Wandb_example\")\n",
    "        ],\n",
    "    )\n",
    "    return analysis.best_config"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "To use the `wandb_mixin` decorator, you can simply decorate the objective function from earlier.\n",
    "Note that we also use `wandb.log(...)` to log the `loss` to Weights & Biases as a dictionary.\n",
    "Otherwise, the decorated version of our objective is identical to its original."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "@wandb_mixin\n",
    "def decorated_objective(config, checkpoint_dir=None):\n",
    "    for i in range(30):\n",
    "        loss = config[\"mean\"] + config[\"sd\"] * np.random.randn()\n",
    "        tune.report(loss=loss)\n",
    "        wandb.log(dict(loss=loss))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "With the `decorated_objective` defined, running a Tune experiment is as simple as providing this objective and\n",
    "passing the `api_key_file` to the `wandb` key of your Tune `config`:"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def tune_decorated(api_key_file):\n",
    "    \"\"\"Example for using the @wandb_mixin decorator with the function API\"\"\"\n",
    "    analysis = tune.run(\n",
    "        decorated_objective,\n",
    "        metric=\"loss\",\n",
    "        mode=\"min\",\n",
    "        config={\n",
    "            \"mean\": tune.grid_search([1, 2, 3, 4, 5]),\n",
    "            \"sd\": tune.uniform(0.2, 0.8),\n",
    "            \"wandb\": {\"api_key_file\": api_key_file, \"project\": \"Wandb_example\"},\n",
    "        },\n",
    "    )\n",
    "    return analysis.best_config"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Finally, you can also define a class-based Tune `Trainable` by using the `WandbTrainableMixin` to define your objective:"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "class WandbTrainable(WandbTrainableMixin, Trainable):\n",
    "    def step(self):\n",
    "        for i in range(30):\n",
    "            loss = self.config[\"mean\"] + self.config[\"sd\"] * np.random.randn()\n",
    "            wandb.log({\"loss\": loss})\n",
    "        return {\"loss\": loss, \"done\": True}"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Running Tune with this `WandbTrainable` works exactly the same as with the function API.\n",
    "The below `tune_trainable` function differs from `tune_decorated` above only in the first argument we pass to\n",
    "`tune.run()`:"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def tune_trainable(api_key_file):\n",
    "    \"\"\"Example for using a WandTrainableMixin with the class API\"\"\"\n",
    "    analysis = tune.run(\n",
    "        WandbTrainable,\n",
    "        metric=\"loss\",\n",
    "        mode=\"min\",\n",
    "        config={\n",
    "            \"mean\": tune.grid_search([1, 2, 3, 4, 5]),\n",
    "            \"sd\": tune.uniform(0.2, 0.8),\n",
    "            \"wandb\": {\"api_key_file\": api_key_file, \"project\": \"Wandb_example\"},\n",
    "        },\n",
    "    )\n",
    "    return analysis.best_config"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Since you may not have an API key for Wandb, we can _mock_ the Wandb logger and test all three of our training\n",
    "functions as follows.\n",
    "If you do have an API key file, make sure to set `mock_api` to `False` and pass in the right `api_key_file` below."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import tempfile\n",
    "from unittest.mock import MagicMock\n",
    "\n",
    "mock_api = True\n",
    "\n",
    "api_key_file = \"~/.wandb_api_key\"\n",
    "\n",
    "if mock_api:\n",
    "    WandbLoggerCallback._logger_process_cls = MagicMock\n",
    "    decorated_objective.__mixins__ = tuple()\n",
    "    WandbTrainable._wandb = MagicMock()\n",
    "    wandb = MagicMock()  # noqa: F811\n",
    "    temp_file = tempfile.NamedTemporaryFile()\n",
    "    temp_file.write(b\"1234\")\n",
    "    temp_file.flush()\n",
    "    api_key_file = temp_file.name\n",
    "\n",
    "tune_function(api_key_file)\n",
    "tune_decorated(api_key_file)\n",
    "tune_trainable(api_key_file)\n",
    "\n",
    "if mock_api:\n",
    "    temp_file.close()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "id": "2f6e9138",
   "metadata": {},
   "source": [
    "This completes our Tune and Wandb walk-through.\n",
    "In the following sections you can find more details on the API of the Tune-Wandb integration.\n",
    "\n",
    "## Tune Wandb API Reference\n",
    "\n",
    "### WandbLoggerCallback\n",
    "\n",
    "(tune-wandb-logger)=\n",
    "\n",
    "```{eval-rst}\n",
    ".. autoclass:: ray.tune.integration.wandb.WandbLoggerCallback\n",
    "   :noindex:\n",
    "```\n",
    "\n",
    "### Wandb-Mixin\n",
    "\n",
    "(tune-wandb-mixin)=\n",
    "\n",
    "```{eval-rst}\n",
    ".. autofunction:: ray.tune.integration.wandb.wandb_mixin\n",
    "   :noindex:\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "orphan": true
 },
 "nbformat": 4,
 "nbformat_minor": 5
}